#!/usr/bin/python -tt
#
# Werner Lustermann, Dominik Neise
# ETH Zurich, TU Dortmund
#
# plotter.py

import numpy as np
import matplotlib.pyplot as plt
import os.path
import sys

# this class was formerly called Plotter in the deprecated
# module plotter.py
class SimplePlotter(object):
    """ minimal x-y plot whose y-values are replaced by calling the instance """
    def __init__(self, name, x, style = 'b', xlabel='x', ylabel='y'):
        """ open a new figure, draw `x` once with `style` and label the axes """
        self.__module__ = 'plotters'
        self.name = name
        self.fig = plt.figure()
        # plt.plot returns a list of lines; exactly one is expected here
        [self.line] = plt.plot(x, style)

        plt.title(name)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.grid(True)

    def __call__(self, ydata):
        """ show `ydata` as the new y-values, fitting the y-axis to its range """
        # make this plot's figure the current one before touching it
        plt.figure(self.fig.number)
        lowest = np.min(ydata)
        highest = np.max(ydata)
        plt.ylim(lowest, highest)
        self.line.set_ydata(ydata)
        plt.draw()
            
class Plotter(object):
    """ x-y plot of one or several curves, completely redrawn on every call """
    def __init__(self, name, x=None, style = '.:', xlabel='x', ylabel='y', ion=True, grid=True, fname=None):
        """ initialize the object

        name   : title of the plot
        x      : optional x-values shared by all curves; if None, each curve
                 is plotted against its sample index
        style  : matplotlib format string used for every curve
        xlabel : label of the x-axis
        ylabel : label of the y-axis
        ion    : switch matplotlib interactive mode on
        grid   : draw a grid
        fname  : if given, the figure is saved to this file on every call
        """
        self.__module__ = 'plotters'
        self.name = name
        self.x = x
        self.style = style
        self.xlabel = xlabel
        self.ylabel = ylabel

        # not sure if this should go here
        if ion:
            plt.ion()

        self.figure = plt.figure()
        self.fig_id = self.figure.number

        plt.grid(grid)
        self.grid = grid
        self.fname = fname

    def _xy_args(self, y):
        """ positional plt.plot arguments for one curve: (x, y, style) or (y, style) """
        # `is None`, not `== None`: x may be a numpy array, and `==` would
        # compare elementwise and make the `if` ambiguous
        if self.x is None:
            return (y, self.style)
        return (self.x, y, self.style)

    def __call__(self, ydata, label=None):
        """ replace the plot content with ydata

        ydata : 1D sequence (one curve) or 2D sequence (one curve per row)
        label : for 1D ydata a single legend label; for 2D ydata a sequence
                holding one label per row. No legend is drawn if it is falsy.
        """
        # make active and clear
        plt.figure(self.fig_id)
        plt.cla()

        ydata = np.array(ydata)
        if ydata.ndim == 1:
            plt.plot(*self._xy_args(ydata), label=label)
        else:
            for i, row in enumerate(ydata):
                if label:
                    plt.plot(*self._xy_args(row), label=label[i])
                else:
                    plt.plot(*self._xy_args(row))

        plt.title(self.name)
        plt.xlabel(self.xlabel)
        plt.ylabel(self.ylabel)
        if label:
            plt.legend()

        # plt.cla() above cleared the axes, so restore the grid setting
        # *before* saving; otherwise the saved file ignores self.grid
        plt.grid(self.grid)
        if self.fname is not None:
            plt.savefig(self.fname)

        plt.draw()
            
        
class CamPlotter(object):
    """ plotting data color-coded into FACT-camera  """

    # number of pixels (CHIDs) expected for full-camera input
    _N_PIXELS = 1440

    def __init__(self, name, ion=True, grid=True, fname=None, map_file_path = 'map_dn.txt', vmin=None, vmax=None, dpi=80, s=25):
        """ initialize the object

        name          : title of the plot
        ion           : switch matplotlib interactive mode on
        grid          : draw a grid
        fname         : stored as self.fname; not used by __call__
        map_file_path : pixel map text file, resolved relative to the
                        directory containing this module
        vmin, vmax    : color scale limits handed to scatter
        dpi           : resolution of the 8 x 8 inch figure
        s             : marker size of one pixel

        Exits the interpreter with status -2 if the map file is missing.
        """
        self.__module__ = 'plotters'
        module_dir = os.path.dirname(os.path.abspath(__file__))
        map_file_path = os.path.join(module_dir, map_file_path)
        if not os.path.isfile(map_file_path):
            print('not able to find file: %s' % map_file_path)
            sys.exit(-2)

        self.name = name
        if ion:
            plt.ion()

        # the map file must have exactly 9 columns; only ye and xe are used
        _chid, _y, _x, ye, xe, _yh, _xh, _softid, _hardid = np.loadtxt(map_file_path, unpack=True)

        # rotate the camera by 90 degrees: (x, y) -> (-y, x)
        self.xe = -ye
        self.ye = xe

        # hexagon marker tuple (numsides, style, angle) for scatter.
        # NOTE(review): matplotlib reads this angle in degrees, so the value
        # below is ~0.52 deg rather than 30 deg -- confirm intended orientation.
        self.H = (6, 0, 30./180.*3.1415926)

        self.figure = plt.figure(figsize=(8, 8), dpi=dpi)
        self.fig_id = self.figure.number

        self.grid = grid
        self.fname = fname
        self.vmin = vmin
        self.vmax = vmax
        self.s = s

    def _draw_pixels(self, chids, values):
        """ draw `values` into the camera pixels `chids` only, then add a colorbar """
        # hack: plot all pixels invisibly first, so the axes look the same
        # as for a full camera
        self.ax.scatter(self.xe, self.ye, s=self.s, alpha=0, marker=self.H)
        result = self.ax.scatter(self.xe[chids], self.ye[chids], s=self.s, alpha=1.,
                    c=np.asarray(values, dtype=float), marker=self.H, linewidths=0.,
                    vmin=self.vmin, vmax=self.vmax)
        self.figure.colorbar(result, shrink=0.8, pad=-0.04)
        plt.draw()

    def __call__(self, data, mask=None):
        """ plot data color-coded into the camera

        data : either one value per pixel (1D, length 1440), or a 2 x N
               array (N <= 1440) whose first row holds CHIDs and whose
               second row holds the corresponding values
        mask : optional, only for 1D data; either a boolean array of
               length 1440 or a sequence of integer CHIDs selecting the
               pixels to draw. An empty mask leaves the figure empty.

        Exits the interpreter with status -1 on a mask it cannot interpret;
        prints a diagnostic (and draws nothing) on badly shaped data.
        """
        # get the figure, clean it, and set it up
        plt.figure(self.fig_id)
        plt.clf()
        self.ax = self.figure.add_subplot(111, aspect='equal')
        self.ax.axis([-22, 22, -22, 22])
        self.ax.set_title(self.name)
        self.ax.grid(self.grid)

        data = np.array(data)

        if mask is not None:
            # `is not None`: `mask != None` on an array compares elementwise
            # and makes the `if` ambiguous
            mask = np.asarray(mask)
            if len(mask) == 0:
                return
            if mask.dtype == bool and data.ndim == 1 and len(mask) == self._N_PIXELS:
                chids = np.flatnonzero(mask)
            elif np.issubdtype(mask.dtype, np.integer) and data.ndim == 1:
                chids = mask
            else:
                print("there is a mask, but I don't know how to treat it!!!")
                sys.exit(-1)
            self._draw_pixels(chids, data[chids])

        elif data.ndim == 1 and len(data) == self._N_PIXELS:
            result = self.ax.scatter(self.xe, self.ye, s=self.s, alpha=1,
                        c=data, marker=self.H, linewidths=0.,
                        vmin=self.vmin, vmax=self.vmax)
            self.figure.colorbar(result, shrink=0.8, pad=-0.04)
            plt.draw()

        elif data.ndim == 2 and data.shape[0] == 2 and data.shape[1] <= self._N_PIXELS:
            # first row: CHIDs, second row: values. np.array() may have made
            # the CHIDs float, which cannot be used as indices -> cast to int
            chids = data[0].astype(int)
            if len(chids) != len(set(chids)):
                print('warning: there are doubled chids in input data '
                      'you might want to plot something else, but I plot it anyway...')
                print(chids)
            self._draw_pixels(chids, data[1])

        else:
            print('CamPlotter call input data has bad format')
            print('data.ndim %s' % (data.ndim,))
            print('data.shape %s' % (data.shape,))
            print('data:----------------------------------')
            print(data)
        
        
        

class HistPlotter(object):
    """ histogram of a data set, completely redrawn on every call """

    def __init__(self, name, bins, range, grid=True, ion=True):
        """ initialize the object

        name  : title of the plot
        bins  : passed to plt.hist as `bins`
        range : passed to plt.hist as `range` (the histogram limits)
        grid  : draw a grid
        ion   : switch matplotlib interactive mode on
        """
        # like Plotter/CamPlotter: enable interactive mode before the
        # figure is created
        if ion:
            plt.ion()

        self.bins = bins
        self.range = range
        self.name = name
        self.figure = plt.figure()
        self.fig_id = self.figure.number
        self.grid = grid

    def __call__(self, ydata, label=None, log=False):
        """ replace the histogram with one of ydata

        ydata : values to histogram; multi-dimensional input is flattened
        label : optional legend label; a legend is drawn only if it is truthy
        log   : use a logarithmic count axis
        """
        plt.figure(self.fig_id)
        plt.cla()

        values = np.ravel(ydata)

        # keywords instead of positional arguments, so bins/range cannot be
        # mixed up with other plt.hist parameters
        hist_kwargs = dict(bins=self.bins, range=self.range, log=log)
        if label:
            plt.hist(values, label=label, **hist_kwargs)
            plt.legend()
        else:
            plt.hist(values, **hist_kwargs)

        plt.title(self.name)
        # the grid option used to be stored but never applied
        plt.grid(self.grid)

        plt.draw()
            
def _test_SimplePlotter():
    """ test of maintaining two independant plotter instances """
    plt.ion()
    
    x = np.linspace(0., 10.)
    plot1 = SimplePlotter('plot1', x, 'r')
    print 'plot1.fig.number: ', plot1.fig.number
    plot2 = SimplePlotter('plot2', x, 'g.')
    print 'plot2.fig.number: ', plot2.fig.number
    
    plot1(np.sin(x) * 7.)
    plot2(x*x)
    
    raw_input('next')
    
    plot1(np.cos(x) * 3.)
    plot2(x)
    
    raw_input('next')


def _test_Plotter():
    """ drive two independent Plotter instances: one with shared x-values
        and a single curve, one without x-values showing several curves,
        first with and then without legend labels
    """
    xs = np.linspace(0., 2*np.pi, 100)
    single = Plotter('plot1', xs, 'r.:')
    multi = Plotter('plot2')

    single(np.sin(xs) * 7)

    # one row per harmonic sin(k*x), k = 1 .. 3, with matching labels
    harmonics = range(1, 3 + 1)
    curves = np.array([np.sin(k * xs) for k in harmonics])
    curve_labels = ['sin(%d*x)' % k for k in harmonics]

    multi(curves, curve_labels)
    raw_input('next')       # keeps the windows open until the user reacts

    single(np.cos(xs) * 3.)
    multi.name += ' without labels!!!'  # the title may change between calls
    multi(curves)
    raw_input('next')


def _test_CamPlotter():
    """ draw a (chids, values) input containing a doubled chid, and a full
        1440-pixel camera, into two separate CamPlotters """
    values = np.array(range(20))
    chids = np.empty(len(values), dtype=int)
    # random pixels, except the last two, which are deliberately both chid 15
    chids[:-2] = [np.random.randint(1440) for _ in range(len(chids) - 2)]
    chids[-2:] = 15

    ramp = np.linspace(0., 1., num=1440)
    cam_sparse = CamPlotter('plot1')
    cam_full = CamPlotter('plot2')

    cam_sparse((chids, values))
    cam_full(ramp)
    raw_input('next')
    
def _test_HistPlotter():
    """ histogram 1000 standard-normal samples, with a legend label """
    plt.ion()

    samples = np.random.randn(1000)
    histogram = HistPlotter('test hist plotter', 34, (-5, 4))
    histogram(samples, 'test-label')
    raw_input('next')

if __name__ == '__main__':
    """ test the class """
    print ' testing SimplePlotter'
    _test_SimplePlotter()
    print ' testing Plotter'
    _test_Plotter()
    print 'testing CamPlotter ... testing what happens if doubled IDs in mask'
    _test_CamPlotter()
    print 'testing basic HistPlotter functionality'
    _test_HistPlotter()
    
